{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 13,
   "source": [
    "import multiprocessing\n",
    "import tensorflow as tf\n",
    "import librosa\n",
    "import numpy as np\n",
    "from jiwer import wer"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "This is the audio file we are going to transcribe, as well as the ground truth transcription"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "source": [
    "audio_file = 'demo_input/84-121550-0000.flac'\n",
    "transcript = 'BUT WITH FULL RAVISHMENT THE HOURS OF PRIME SINGING RECEIVED THEY IN THE MIDST OF LEAVES THAT EVER BORE A BURDEN TO THEIR RHYMES'"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "We first convert the transcript into integers, as well as defining a reverse mapping for decoding the final output."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "source": [
    "alphabet = \"abcdefghijklmnopqrstuvwxyz' @\"\n",
    "alphabet_dict = {c: ind for (ind, c) in enumerate(alphabet)}\n",
    "index_dict = {ind: c for (ind, c) in enumerate(alphabet)}\n",
    "transcript_ints = [alphabet_dict[letter] for letter in transcript.lower()]\n",
    "print(transcript_ints)"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "[1, 20, 19, 27, 22, 8, 19, 7, 27, 5, 20, 11, 11, 27, 17, 0, 21, 8, 18, 7, 12, 4, 13, 19, 27, 19, 7, 4, 27, 7, 14, 20, 17, 18, 27, 14, 5, 27, 15, 17, 8, 12, 4, 27, 18, 8, 13, 6, 8, 13, 6, 27, 17, 4, 2, 4, 8, 21, 4, 3, 27, 19, 7, 4, 24, 27, 8, 13, 27, 19, 7, 4, 27, 12, 8, 3, 18, 19, 27, 14, 5, 27, 11, 4, 0, 21, 4, 18, 27, 19, 7, 0, 19, 27, 4, 21, 4, 17, 27, 1, 14, 17, 4, 27, 0, 27, 1, 20, 17, 3, 4, 13, 27, 19, 14, 27, 19, 7, 4, 8, 17, 27, 17, 7, 24, 12, 4, 18]\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "We then load the audio file and convert it to MFCCs (with an extra batch dimension)."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "source": [
    "def normalize(values):\n",
    "    \"\"\"\n",
    "    Normalize values to mean 0 and std 1\n",
    "    \"\"\"\n",
    "    return (values - np.mean(values)) / np.std(values)\n",
    "\n",
    "def transform_audio_to_mfcc(audio_file, transcript, n_mfcc=13, n_fft=512, hop_length=160):\n",
    "    audio_data, sample_rate = librosa.load(audio_file, sr=16000)\n",
    "\n",
    "    mfcc = librosa.feature.mfcc(audio_data, sr=sample_rate, n_mfcc=n_mfcc, n_fft=n_fft, hop_length=hop_length)\n",
    "\n",
    "    # add derivatives and normalize\n",
    "    mfcc_delta = librosa.feature.delta(mfcc)\n",
    "    mfcc_delta2 = librosa.feature.delta(mfcc, order=2)\n",
    "    mfcc = np.concatenate((normalize(mfcc), normalize(mfcc_delta), normalize(mfcc_delta2)), axis=0)\n",
    "\n",
    "    seq_length = mfcc.shape[1] // 2\n",
    "\n",
    "    sequences = np.concatenate([[seq_length], transcript]).astype(np.int32)\n",
    "    sequences = np.expand_dims(sequences, 0)\n",
    "    mfcc_out = mfcc.T.astype(np.float32)\n",
    "    mfcc_out = np.expand_dims(mfcc_out, 0)\n",
    "\n",
    "    return mfcc_out, sequences"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "source": [
    "def log(std):\n",
    "    \"\"\"Log the given string to the standard output.\"\"\"\n",
    "    print(\"******* {}\".format(std), flush=True)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "We use the ctc decoder to decode the output of the network"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "source": [
    "def ctc_preparation(tensor, y_predict):\n",
    "    if len(y_predict.shape) == 4:\n",
    "        y_predict = tf.squeeze(y_predict, axis=1)\n",
    "    y_predict = tf.transpose(y_predict, (1, 0, 2))\n",
    "    sequence_lengths, labels = tensor[:, 0], tensor[:, 1:]\n",
    "    idx = tf.where(tf.not_equal(labels, 28))\n",
    "    sparse_labels = tf.SparseTensor(\n",
    "        idx, tf.gather_nd(labels, idx), tf.shape(labels, out_type=tf.int64)\n",
    "    )\n",
    "    return sparse_labels, sequence_lengths, y_predict\n",
    "\n",
    "def ctc_ler(y_true, y_predict):\n",
    "    sparse_labels, logit_length, y_predict = ctc_preparation(y_true, y_predict)\n",
    "    decoded, log_probabilities = tf.nn.ctc_greedy_decoder(\n",
    "        y_predict, tf.cast(logit_length, tf.int32), merge_repeated=True\n",
    "    )\n",
    "    return tf.reduce_mean(\n",
    "        tf.edit_distance(\n",
    "            tf.cast(decoded[0], tf.int32), tf.cast(sparse_labels, tf.int32)\n",
    "        ).numpy()\n",
    "    ), tf.sparse.to_dense(decoded[0]).numpy()\n",
    "\n",
    "def trans_int_to_string(trans_int):\n",
    "    #create dictionary int -> string (0 -> a 1 -> b)\n",
    "    string = \"\"\n",
    "    alphabet = \"abcdefghijklmnopqrstuvwxyz' @\"\n",
    "    alphabet_dict = {}\n",
    "    count = 0\n",
    "    for x in alphabet:\n",
    "        alphabet_dict[count] = x\n",
    "        count += 1\n",
    "    for letter in trans_int:\n",
    "        letter_np = np.array(letter).item(0)\n",
    "        if letter_np != 28:\n",
    "            string += alphabet_dict[letter_np]\n",
    "    return string\n",
    "\n",
    "def ctc_wer(y_true, y_predict):\n",
    "    sparse_labels, logit_length, y_predict = ctc_preparation(y_true, y_predict)\n",
    "    decoded, log_probabilities = tf.nn.ctc_greedy_decoder(\n",
    "            y_predict, tf.cast(logit_length, tf.int32), merge_repeated=True\n",
    "    )\n",
    "    true_sentence = tf.cast(sparse_labels.values, tf.int32)\n",
    "    return wer(str(trans_int_to_string(decoded[0].values)), str(trans_int_to_string(true_sentence)))"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "The TFLite file requires inputs of size 296, so we apply a window to the input"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "source": [
    "def evaluate_tflite(tflite_path, input_window_length =  296):\n",
    "    \"\"\"Evaluates tflite (fp32, int8).\"\"\"\n",
    "    results = []\n",
    "    data, label = transform_audio_to_mfcc(audio_file, transcript_ints)\n",
    "\n",
    "    interpreter = tf.lite.Interpreter(model_path=tflite_path, num_threads=multiprocessing.cpu_count())\n",
    "    interpreter.allocate_tensors()\n",
    "    input_chunk = interpreter.get_input_details()[0]\n",
    "    output_details = interpreter.get_output_details()[0]\n",
    "\n",
    "    input_shape = input_chunk[\"shape\"]\n",
    "    log(\"eval_model() - input_shape: {}\".format(input_shape))\n",
    "    input_dtype = input_chunk[\"dtype\"]\n",
    "    output_dtype = output_details[\"dtype\"]\n",
    "\n",
    "    # Check if the input/output type is quantized,\n",
    "    # set scale and zero-point accordingly\n",
    "    if input_dtype != tf.float32:\n",
    "        input_scale, input_zero_point = input_chunk[\"quantization\"]\n",
    "    else:\n",
    "        input_scale, input_zero_point = 1, 0\n",
    "\n",
    "    if output_dtype != tf.float32:\n",
    "        output_scale, output_zero_point = output_details[\"quantization\"]\n",
    "    else:\n",
    "        output_scale, output_zero_point = 1, 0\n",
    "\n",
    "\n",
    "    data = data / input_scale + input_zero_point\n",
    "    # Round the data up if dtype is int8, uint8 or int16\n",
    "    if input_dtype is not np.float32:\n",
    "        data = np.round(data)\n",
    "\n",
    "    while data.shape[1] < input_window_length:\n",
    "        data = np.append(data, data[:, -2:-1, :], axis=1)\n",
    "    # Zero-pad any odd-length inputs\n",
    "    if data.shape[1] % 2 == 1:\n",
    "        # log('Input length is odd, zero-padding to even (first layer has stride 2)')\n",
    "        data = np.concatenate([data, np.zeros((1, 1, data.shape[2]), dtype=input_dtype)], axis=1)\n",
    "\n",
    "    context = 24 + 2 * (7 * 3 + 16)  # = 98 - theoretical max receptive field on each side\n",
    "    size = input_chunk['shape'][1]\n",
    "    inner = size - 2 * context\n",
    "    data_end = data.shape[1]\n",
    "\n",
    "    # Initialize variables for the sliding window loop\n",
    "    data_pos = 0\n",
    "    outputs = []\n",
    "\n",
    "    while data_pos < data_end:\n",
    "        if data_pos == 0:\n",
    "            # Align inputs from the first window to the start of the data and include the intial context in the output\n",
    "            start = data_pos\n",
    "            end = start + size\n",
    "            y_start = 0\n",
    "            y_end = y_start + (size - context) // 2\n",
    "            data_pos = end - context\n",
    "        elif data_pos + inner + context >= data_end:\n",
    "            # Shift left to align final window to the end of the data and include the final context in the output\n",
    "            shift = (data_pos + inner + context) - data_end\n",
    "            start = data_pos - context - shift\n",
    "            end = start + size\n",
    "            assert start >= 0\n",
    "            y_start = (shift + context) // 2  # Will be even because we assert it above\n",
    "            y_end = size // 2\n",
    "            data_pos = data_end\n",
    "        else:\n",
    "            # Capture only the inner region from mid-input inferences, excluding output from both context regions\n",
    "            start = data_pos - context\n",
    "            end = start + size\n",
    "            y_start = context // 2\n",
    "            y_end = y_start + inner // 2\n",
    "            data_pos = end - context\n",
    "\n",
    "        interpreter.set_tensor(input_chunk[\"index\"], tf.cast(data[:, start:end, :], input_dtype))\n",
    "        interpreter.invoke()\n",
    "        cur_output_data = interpreter.get_tensor(output_details[\"index\"])[:, :, y_start:y_end, :]\n",
    "        cur_output_data = output_scale * (\n",
    "                cur_output_data.astype(np.float32) - output_zero_point\n",
    "        )\n",
    "        outputs.append(cur_output_data)\n",
    "\n",
    "    complete = np.concatenate(outputs, axis=2)\n",
    "    LER, output = ctc_ler(label, complete)\n",
    "    WER = ctc_wer(label, complete)\n",
    "    return output, LER , WER\n"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 107,
   "source": [
    "wav2letter_tflite_path = \"tflite_int8/tiny_wav2letter_int8.tflite\"\n",
    "output, LER , WER = evaluate_tflite(wav2letter_tflite_path)\n",
    "\n",
    "decoded_output = [index_dict[value] for value in output[0]]\n",
    "log(f'Transcribed File: {\"\".join(decoded_output)}')\n",
    "log(f'Letter Error Rate is {LER}')\n",
    "log(f'Word Error Rate is {WER}')"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "******* eval_model() - input_shape: [  1 296  39]\n",
      "******* Input length is odd, zero-padding to even (first layer has stride 2)\n",
      "******* Transcribed File: but with full ravishment the hours of prime singing received they in the midst of leaves that everborea burden to their rimes\n",
      "******* Letter Error Rate is 0.03125\n",
      "******* Word Error Rate is 1.05\n"
     ]
    }
   ],
   "metadata": {}
  }
 ],
 "metadata": {
  "orig_nbformat": 4,
  "language_info": {
   "name": "python",
   "version": "3.8.2",
   "mimetype": "text/x-python",
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "pygments_lexer": "ipython3",
   "nbconvert_exporter": "python",
   "file_extension": ".py"
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3.8.2 64-bit ('env': venv)"
  },
  "interpreter": {
   "hash": "4b529a2edd0e262cfd8353ba70b138cbba10314325c544d99b9316c477c7841b"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
